{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "rNdWfPXCjTjY"
   },
   "source": [
    "## Processing the Heart Disease Dataset with TensorFlow feature_column"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Workflow:\n",
    "\n",
    "1. Load the CSV file with Pandas.\n",
    "2. Read the data with tf.data, shuffling and batching it.\n",
    "3. Use feature_column to map the CSV columns to the features used to train the model.\n",
    "4. Build, train, and evaluate the model with Keras."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "K1y4OHpGgss7"
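   },
   "source": [
    "### The heart disease dataset\n",
    "\n",
    "The dataset has a few hundred rows. Each row describes one patient, and each column describes one attribute.\n",
    "\n",
    "We will use this information to predict whether a patient has heart disease, which is a binary classification task on this dataset.\n",
    "\n",
    "The columns are described below. Note that some features are numerical and others are categorical:\n",
    "\n",
    ">Column | Description | Feature type | Data type\n",
    ">------------|--------------------|----------------------|-----------------\n",
    ">Age | Age in years | Numerical | integer\n",
    ">Sex | (1 = male; 0 = female) | Categorical | integer\n",
    ">CP | Chest pain type (0, 1, 2, 3, 4) | Categorical | integer\n",
    ">Trestbps | Resting blood pressure (in mm Hg on admission to the hospital) | Numerical | integer\n",
    ">Chol | Serum cholesterol (mg/dl) | Numerical | integer\n",
    ">FBS | Fasting blood sugar > 120 mg/dl (1 = true; 0 = false) | Categorical | integer\n",
    ">RestECG | Resting electrocardiographic results (0, 1, 2) | Categorical | integer\n",
    ">Thalach | Maximum heart rate achieved | Numerical | integer\n",
    ">Exang | Exercise-induced angina (1 = yes; 0 = no) | Categorical | integer\n",
    ">Oldpeak | ST depression induced by exercise relative to rest | Numerical | float\n",
    ">Slope | Slope of the peak exercise ST segment | Numerical | integer\n",
    ">CA | Number of major vessels (0-3) colored by fluoroscopy | Numerical | integer\n",
    ">Thal | 3 = normal; 6 = fixed defect; 7 = reversible defect (stored in this CSV as the strings `normal`, `fixed`, `reversible`) | Categorical | string\n",
    ">Target | Diagnosis of heart disease (1 = true; 0 = false) | Classification | integer"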
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "VxyBFc_kKazA"
   },
   "source": [
    "### 1. Import libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "9dEreb4QKizj"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "from tensorflow import feature_column\n",
    "from tensorflow.keras import layers\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "KCEhSZcULZ9n"
   },
   "source": [
    "### 2. Read the CSV with Pandas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "REZ57BXCLdfG"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>sex</th>\n",
       "      <th>cp</th>\n",
       "      <th>trestbps</th>\n",
       "      <th>chol</th>\n",
       "      <th>fbs</th>\n",
       "      <th>restecg</th>\n",
       "      <th>thalach</th>\n",
       "      <th>exang</th>\n",
       "      <th>oldpeak</th>\n",
       "      <th>slope</th>\n",
       "      <th>ca</th>\n",
       "      <th>thal</th>\n",
       "      <th>target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>63</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>145</td>\n",
       "      <td>233</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>150</td>\n",
       "      <td>0</td>\n",
       "      <td>2.3</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>fixed</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>67</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "      <td>160</td>\n",
       "      <td>286</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>108</td>\n",
       "      <td>1</td>\n",
       "      <td>1.5</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>normal</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>67</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "      <td>120</td>\n",
       "      <td>229</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>129</td>\n",
       "      <td>1</td>\n",
       "      <td>2.6</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>reversible</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>37</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>130</td>\n",
       "      <td>250</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>187</td>\n",
       "      <td>0</td>\n",
       "      <td>3.5</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>normal</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>41</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>130</td>\n",
       "      <td>204</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>172</td>\n",
       "      <td>0</td>\n",
       "      <td>1.4</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>normal</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age  sex  cp  trestbps  chol  fbs  restecg  thalach  exang  oldpeak  slope  \\\n",
       "0   63    1   1       145   233    1        2      150      0      2.3      3   \n",
       "1   67    1   4       160   286    0        2      108      1      1.5      2   \n",
       "2   67    1   4       120   229    0        2      129      1      2.6      2   \n",
       "3   37    1   3       130   250    0        0      187      0      3.5      3   \n",
       "4   41    0   2       130   204    0        2      172      0      1.4      1   \n",
       "\n",
       "   ca        thal  target  \n",
       "0   0       fixed       0  \n",
       "1   3      normal       1  \n",
       "2   2  reversible       0  \n",
       "3   0      normal       0  \n",
       "4   0      normal       0  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv(\"./datas/heart/heart.csv\")\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "u0zhLtQqMPem"
   },
   "source": [
    "### 3. Split the dataframe into training, validation, and test sets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "YEOpw7LhMYsI"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "193 train examples\n",
      "49 validation examples\n",
      "61 test examples\n"
     ]
    }
   ],
   "source": [
    "train, test = train_test_split(df, test_size=0.2)\n",
    "train, val = train_test_split(train, test_size=0.2)\n",
    "print(len(train), 'train examples')\n",
    "print(len(val), 'validation examples')\n",
    "print(len(test), 'test examples')"
   ]
  },
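  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The three sizes printed above follow from how a fractional `test_size` is rounded: to our understanding of scikit-learn, the held-out share is rounded up and the remainder stays on the training side. The cell below is a minimal sketch of that arithmetic. `split_sizes` is a hypothetical helper written for this illustration, not a scikit-learn function, and the total of 303 rows is inferred from the printed sizes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def split_sizes(n_rows, test_size):\n",
    "    # Hypothetical helper: mimic the rounding of a fractional test_size\n",
    "    n_held_out = math.ceil(n_rows * test_size)  # held-out share is rounded up\n",
    "    n_kept = n_rows - n_held_out                # the remainder stays for training\n",
    "    return n_kept, n_held_out\n",
    "\n",
    "n_total = 193 + 49 + 61  # 303 rows, inferred from the sizes printed above\n",
    "n_train_val, n_test = split_sizes(n_total, 0.2)  # first split: hold out the test set\n",
    "n_train, n_val = split_sizes(n_train_val, 0.2)   # second split: hold out the validation set\n",
    "print(n_train, n_val, n_test)"
   ]
  },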
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "84ef46LXMfvu"
   },
   "source": [
    "### 4. Read the data with tf.data.Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "NkcaMYP-MsRe"
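   },
   "outputs": [],
   "source": [
    "def df_to_dataset(df, shuffle=True, batch_size=32):\n",
    "    \"\"\"Convert a pandas DataFrame into a batched tf.data.Dataset of (features, labels).\"\"\"\n",
    "    df = df.copy()  # copy first so that pop() does not mutate the caller's DataFrame\n",
    "    labels = df.pop('target')\n",
    "    ds = tf.data.Dataset.from_tensor_slices((dict(df), labels))\n",
    "    if shuffle:\n",
    "        # A buffer as large as the dataset gives a full shuffle\n",
    "        ds = ds.shuffle(buffer_size=len(df))\n",
    "    ds = ds.batch(batch_size)\n",
    "    return ds"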
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "CXbbXkJvMy34"
   },
   "outputs": [],
   "source": [
    "batch_size = 32\n",
    "train_ds = df_to_dataset(train, shuffle=True, batch_size=batch_size)\n",
    "val_ds = df_to_dataset(val, shuffle=False, batch_size=batch_size)\n",
    "test_ds = df_to_dataset(test, shuffle=False, batch_size=batch_size)"
   ]
  },
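  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `batch_size = 32`, the number of batches in each dataset is its size divided by the batch size, rounded up, because `batch` keeps the final partial batch by default. The cell below sketches that count for the split sizes printed earlier; `num_batches` is a hypothetical helper used only for this illustration. The training count is the number of steps Keras runs per epoch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def num_batches(n_examples, batch_size):\n",
    "    # Hypothetical helper: a final partial batch still counts as one batch\n",
    "    return math.ceil(n_examples / batch_size)\n",
    "\n",
    "batch_counts = {name: num_batches(n, 32)\n",
    "                for name, n in [('train', 193), ('val', 49), ('test', 61)]}\n",
    "print(batch_counts)"
   ]
  },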
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "CSBo3dUVNFc9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'age': <tf.Tensor: shape=(32,), dtype=int64, numpy=\n",
      "array([65, 45, 46, 59, 71, 67, 67, 48, 65, 47, 39, 55, 65, 57, 67, 29, 49,\n",
      "       46, 62, 44, 56, 62, 54, 64, 59, 44, 45, 57, 50, 45, 56, 46])>, 'sex': <tf.Tensor: shape=(32,), dtype=int64, numpy=\n",
      "array([0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0,\n",
      "       1, 1, 1, 1, 0, 1, 1, 1, 0, 0])>, 'cp': <tf.Tensor: shape=(32,), dtype=int64, numpy=\n",
      "array([3, 4, 2, 1, 2, 3, 4, 2, 4, 3, 3, 4, 4, 4, 4, 2, 4, 4, 4, 4, 2, 3,\n",
      "       4, 4, 0, 2, 2, 2, 4, 4, 4, 4])>, 'trestbps': <tf.Tensor: shape=(32,), dtype=int64, numpy=\n",
      "array([160, 138, 105, 178, 160, 115, 125, 110, 110, 108,  94, 180, 135,\n",
      "       110, 160, 130, 130, 140, 124, 120, 140, 130, 122, 120, 164, 130,\n",
      "       130, 124, 150, 142, 200, 138])>, 'chol': <tf.Tensor: shape=(32,), dtype=int64, numpy=\n",
      "array([360, 236, 204, 270, 302, 564, 254, 229, 248, 243, 199, 327, 254,\n",
      "       201, 286, 204, 269, 311, 209, 169, 294, 263, 286, 246, 176, 219,\n",
      "       234, 261, 243, 309, 288, 243])>, 'fbs': <tf.Tensor: shape=(32,), dtype=int64, numpy=\n",
      "array([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "       0, 0, 1, 0, 0, 0, 0, 0, 1, 0])>, 'restecg': <tf.Tensor: shape=(32,), dtype=int64, numpy=\n",
      "array([2, 2, 0, 2, 0, 2, 0, 0, 2, 0, 0, 1, 2, 0, 2, 2, 0, 0, 0, 0, 2, 0,\n",
      "       2, 2, 0, 2, 2, 0, 2, 2, 2, 2])>, 'thalach': <tf.Tensor: shape=(32,), dtype=int64, numpy=\n",
      "array([151, 152, 172, 145, 162, 160, 163, 168, 158, 152, 179, 117, 127,\n",
      "       126, 108, 202, 163, 120, 163, 144, 153,  97, 116,  96,  90, 188,\n",
      "       175, 141, 128, 147, 133, 152])>, 'exang': <tf.Tensor: shape=(32,), dtype=int64, numpy=\n",
      "array([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0,\n",
      "       1, 1, 0, 0, 0, 0, 0, 1, 1, 1])>, 'oldpeak': <tf.Tensor: shape=(32,), dtype=float64, numpy=\n",
      "array([0.8, 0.2, 0. , 4.2, 0.4, 1.6, 0.2, 1. , 0.6, 0. , 0. , 3.4, 2.8,\n",
      "       1.5, 1.5, 0. , 0. , 1.8, 0. , 2.8, 1.3, 1.2, 3.2, 2.2, 1. , 0. ,\n",
      "       0.6, 0.3, 2.6, 0. , 4. , 0. ])>, 'slope': <tf.Tensor: shape=(32,), dtype=int64, numpy=\n",
      "array([1, 2, 1, 3, 1, 2, 2, 3, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 1, 3, 2, 2,\n",
      "       2, 3, 1, 1, 2, 1, 2, 2, 3, 2])>, 'ca': <tf.Tensor: shape=(32,), dtype=int64, numpy=\n",
      "array([0, 0, 0, 0, 2, 0, 2, 0, 2, 0, 0, 0, 1, 0, 3, 0, 0, 2, 0, 0, 0, 1,\n",
      "       2, 1, 2, 0, 0, 0, 0, 3, 2, 0])>, 'thal': <tf.Tensor: shape=(32,), dtype=string, numpy=\n",
      "array([b'normal', b'normal', b'normal', b'reversible', b'normal',\n",
      "       b'reversible', b'reversible', b'reversible', b'fixed', b'normal',\n",
      "       b'normal', b'normal', b'reversible', b'fixed', b'normal',\n",
      "       b'normal', b'normal', b'reversible', b'normal', b'fixed',\n",
      "       b'normal', b'reversible', b'normal', b'normal', b'1', b'normal',\n",
      "       b'normal', b'reversible', b'reversible', b'reversible',\n",
      "       b'reversible', b'normal'], dtype=object)>}\n",
      "\n",
      "tf.Tensor(\n",
      "[65 45 46 59 71 67 67 48 65 47 39 55 65 57 67 29 49 46 62 44 56 62 54 64\n",
      " 59 44 45 57 50 45 56 46], shape=(32,), dtype=int64)\n",
      "\n",
      "tf.Tensor([0 0 0 0 0 0 1 0 0 0 0 1 1 0 1 0 0 1 0 1 0 1 1 1 0 0 0 0 1 1 1 0], shape=(32,), dtype=int64)\n"
     ]
    }
   ],
   "source": [
    "# Inspect one batch from the dataset\n",
    "for feature_batch, label_batch in train_ds.take(1):\n",
    "    print(feature_batch)\n",
    "    print()\n",
    "    print(feature_batch[\"age\"])\n",
    "    print()\n",
    "    print(label_batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ypkI9zx6Rj1q"
   },
   "source": [
    "### 5. Feature processing with feature_column\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "4PlLY7fORuzA"
   },
   "outputs": [],
   "source": [
    "feature_columns = []\n",
    "\n",
    "# Numeric columns\n",
    "for header in ['age', 'trestbps', 'chol', 'thalach', 'oldpeak', 'slope', 'ca']:\n",
    "    feature_columns.append(feature_column.numeric_column(header))\n",
    "\n",
    "# Bucketized column\n",
    "age = feature_column.numeric_column(\"age\")\n",
    "age_buckets = feature_column.bucketized_column(age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])\n",
    "feature_columns.append(age_buckets)\n",
    "\n",
    "# Categorical column, one-hot encoded through indicator_column\n",
    "thal = feature_column.categorical_column_with_vocabulary_list(\n",
    "    'thal', ['fixed', 'normal', 'reversible'])\n",
    "thal_one_hot = feature_column.indicator_column(thal)\n",
    "feature_columns.append(thal_one_hot)\n",
    "\n",
    "# Embedding column\n",
    "thal_embedding = feature_column.embedding_column(thal, dimension=4)\n",
    "feature_columns.append(thal_embedding)\n",
    "\n",
    "# Crossed column\n",
    "crossed_feature = feature_column.crossed_column([age_buckets, thal], hash_bucket_size=100)\n",
    "crossed_feature = feature_column.indicator_column(crossed_feature)\n",
    "feature_columns.append(crossed_feature)"
   ]
  },
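  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the bucketized column more concrete: the ten boundaries above split age into eleven half-open intervals, each including its left boundary, and every age becomes a one-hot vector marking its interval. The cell below reproduces that mapping with NumPy instead of TensorFlow, so it is only an illustration of the idea, not the library's implementation. The names `example_ages`, `bucket_index`, and `one_hot` are introduced here for the example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "boundaries = [18, 25, 30, 35, 40, 45, 50, 55, 60, 65]\n",
    "example_ages = np.array([63, 67, 37, 41, 29])\n",
    "\n",
    "# np.digitize returns i such that boundaries[i-1] <= age < boundaries[i]\n",
    "bucket_index = np.digitize(example_ages, boundaries)\n",
    "\n",
    "# 10 boundaries give 11 buckets; each row holds a single 1 at its bucket index\n",
    "one_hot = np.eye(len(boundaries) + 1, dtype=int)[bucket_index]\n",
    "\n",
    "print(bucket_index)\n",
    "print(one_hot.shape)"
   ]
  },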
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "bBx4Xu0eTXWq"
   },
   "source": [
    "### 6. Create, compile, and train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "_YJPPb3xTPeZ"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/30\n",
      "WARNING:tensorflow:Layer dense_features is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because it's dtype defaults to floatx.\n",
      "\n",
      "If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.\n",
      "\n",
      "To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.\n",
      "\n",
      "7/7 [==============================] - 0s 42ms/step - loss: 6.4846 - accuracy: 0.7150 - val_loss: 2.8021 - val_accuracy: 0.2857\n",
      "Epoch 2/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 1.5845 - accuracy: 0.5596 - val_loss: 2.5465 - val_accuracy: 0.7143\n",
      "Epoch 3/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 2.0645 - accuracy: 0.7150 - val_loss: 0.7416 - val_accuracy: 0.6327\n",
      "Epoch 4/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 0.6792 - accuracy: 0.7150 - val_loss: 0.6082 - val_accuracy: 0.7347\n",
      "Epoch 5/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 0.6407 - accuracy: 0.7254 - val_loss: 0.6311 - val_accuracy: 0.6735\n",
      "Epoch 6/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 0.5915 - accuracy: 0.7720 - val_loss: 0.5975 - val_accuracy: 0.7959\n",
      "Epoch 7/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 0.7065 - accuracy: 0.6425 - val_loss: 0.9846 - val_accuracy: 0.7143\n",
      "Epoch 8/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 0.8068 - accuracy: 0.6943 - val_loss: 0.9719 - val_accuracy: 0.5918\n",
      "Epoch 9/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 0.9149 - accuracy: 0.6269 - val_loss: 1.4216 - val_accuracy: 0.7143\n",
      "Epoch 10/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 1.0009 - accuracy: 0.7565 - val_loss: 0.6289 - val_accuracy: 0.7143\n",
      "Epoch 11/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 0.6719 - accuracy: 0.7047 - val_loss: 0.8216 - val_accuracy: 0.6939\n",
      "Epoch 12/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 0.6143 - accuracy: 0.7513 - val_loss: 0.9296 - val_accuracy: 0.6327\n",
      "Epoch 13/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 0.9585 - accuracy: 0.6321 - val_loss: 1.2613 - val_accuracy: 0.7143\n",
      "Epoch 14/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 0.8586 - accuracy: 0.7409 - val_loss: 0.4195 - val_accuracy: 0.8367\n",
      "Epoch 15/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 0.5431 - accuracy: 0.7720 - val_loss: 0.4846 - val_accuracy: 0.6939\n",
      "Epoch 16/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 0.4645 - accuracy: 0.8083 - val_loss: 0.4018 - val_accuracy: 0.7551\n",
      "Epoch 17/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 0.4154 - accuracy: 0.7979 - val_loss: 0.3967 - val_accuracy: 0.7551\n",
      "Epoch 18/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 0.4007 - accuracy: 0.7979 - val_loss: 0.3876 - val_accuracy: 0.7551\n",
      "Epoch 19/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 0.4068 - accuracy: 0.7927 - val_loss: 0.4544 - val_accuracy: 0.7143\n",
      "Epoch 20/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 0.4786 - accuracy: 0.7772 - val_loss: 0.5740 - val_accuracy: 0.7347\n",
      "Epoch 21/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 0.5977 - accuracy: 0.7306 - val_loss: 0.4511 - val_accuracy: 0.7347\n",
      "Epoch 22/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 0.5956 - accuracy: 0.7513 - val_loss: 0.5816 - val_accuracy: 0.7143\n",
      "Epoch 23/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 0.6514 - accuracy: 0.7565 - val_loss: 0.5628 - val_accuracy: 0.7551\n",
      "Epoch 24/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 0.5310 - accuracy: 0.7720 - val_loss: 0.3644 - val_accuracy: 0.8367\n",
      "Epoch 25/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 0.6400 - accuracy: 0.7150 - val_loss: 0.6804 - val_accuracy: 0.7347\n",
      "Epoch 26/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 0.6860 - accuracy: 0.7306 - val_loss: 0.8425 - val_accuracy: 0.6735\n",
      "Epoch 27/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 1.0230 - accuracy: 0.5751 - val_loss: 1.2422 - val_accuracy: 0.7143\n",
      "Epoch 28/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 1.2425 - accuracy: 0.7306 - val_loss: 0.3380 - val_accuracy: 0.8571\n",
      "Epoch 29/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 0.3581 - accuracy: 0.8135 - val_loss: 0.3282 - val_accuracy: 0.8367\n",
      "Epoch 30/30\n",
      "7/7 [==============================] - 0s 3ms/step - loss: 0.3520 - accuracy: 0.8238 - val_loss: 0.3332 - val_accuracy: 0.8571\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f0e6047c110>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = tf.keras.Sequential([\n",
    "    layers.DenseFeatures(feature_columns),\n",
    "    layers.Dense(64, activation='relu'),\n",
    "    layers.Dense(32, activation='relu'),\n",
    "    layers.Dense(1, activation='sigmoid')\n",
    "])\n",
    "\n",
    "model.compile(optimizer='adam',\n",
    "              loss='binary_crossentropy',\n",
    "              metrics=['accuracy'])\n",
    "\n",
    "model.fit(train_ds,\n",
    "          validation_data=val_ds,\n",
    "          epochs=30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "GnFmMOW0Tcaa"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2/2 [==============================] - 0s 1ms/step - loss: 0.6322 - accuracy: 0.7377\n",
      "Accuracy 0.7377049326896667\n"
     ]
    }
   ],
   "source": [
    "loss, accuracy = model.evaluate(test_ds)\n",
    "print(\"Accuracy\", accuracy)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "feature_columns.ipynb",
   "private_outputs": true,
   "provenance": [],
   "toc_visible": true,
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
